{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets, transforms\n",
    "import torch.utils.data\n",
    "import numpy as np\n",
    "from sklearn.cluster import KMeans\n",
    "import math\n",
    "from copy import deepcopy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.1626, 0.2076, 0.4857, 0.8813],\n",
      "        [0.8368, 0.0530, 0.1879, 0.5520],\n",
      "        [0.5488, 0.7655, 0.1304, 0.3950]])\n",
      "tensor([[0.1624, 0.2079, 0.4873, 0.8283],\n",
      "        [0.8283, 0.0520, 0.1884, 0.5522],\n",
      "        [0.5490, 0.7666, 0.1299, 0.3963]])\n"
     ]
    }
   ],
   "source": [
    "def quantize_tensor(x, num_bits=8):\n",
    "    qmin = 0.\n",
    "    qmax = 2.**num_bits - 1.\n",
    "    min_val, max_val = x.min(), x.max()\n",
    "\n",
    "    scale = (max_val - min_val) / (qmax - qmin)\n",
    "\n",
    "    initial_zero_point = qmin - min_val / scale\n",
    "\n",
    "    zero_point = 0\n",
    "    if initial_zero_point < qmin:\n",
    "        zero_point = qmin\n",
    "    elif initial_zero_point > qmax:\n",
    "        zero_point = qmax\n",
    "    else:\n",
    "        zero_point = initial_zero_point\n",
    "\n",
    "    zero_point = int(zero_point)\n",
    "    q_x = zero_point + x / scale\n",
    "    q_x.clamp_(qmin, qmax).round_()\n",
    "    q_x = q_x.round().byte()\n",
    "    return q_x, scale, zero_point\n",
    "\n",
    "\n",
    "def dequantize_tensor(q_x, scale, zero_point):\n",
    "    return scale * (q_x.float() - zero_point)\n",
    "\n",
    "w = torch.rand(3,4)\n",
    "print(w)\n",
    "\n",
    "print(dequantize_tensor(*quantize_tensor(w)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "class QuantLinear(nn.Linear):\n",
    "    def __init__(self, in_features, out_features, bias=True):\n",
    "        super(QuantLinear, self).__init__(in_features, out_features, bias)\n",
    "        self.quant_flag = False\n",
    "        self.scale = None\n",
    "        self.zero_point = None\n",
    "    \n",
    "    def linear_quant(self, quantize_bit=8):\n",
    "        self.weight.data, self.scale, self.zero_point = quantize_tensor(self.weight.data, num_bits=quantize_bit)\n",
    "        self.quant_flag = True\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.quant_flag == True:\n",
    "            weight = dequantize_tensor(self.weight, self.scale, self.zero_point)\n",
    "            return F.linear(x, weight, self.bias)\n",
    "        else:\n",
    "            return F.linear(x, self.weight, self.bias)\n",
    "        \n",
    "            \n",
    "class QuantConv2d(nn.Conv2d):\n",
    "    def __init__(self, in_channels, out_channels, kernel_size, stride=1,\n",
    "                 padding=0, dilation=1, groups=1, bias=True):\n",
    "        super(QuantConv2d, self).__init__(in_channels, out_channels, \n",
    "            kernel_size, stride, padding, dilation, groups, bias)\n",
    "        self.quant_flag = False\n",
    "        self.scale = None\n",
    "        self.zero_point = None\n",
    "    \n",
    "    def linear_quant(self, quantize_bit=8):\n",
    "        self.weight.data, self.scale, self.zero_point = quantize_tensor(self.weight.data, num_bits=quantize_bit)\n",
    "        self.quant_flag = True\n",
    "        \n",
    "    def forward(self, x):\n",
    "        if self.quant_flag == True:\n",
    "            weight = dequantize_tensor(self.weight, self.scale, self.zero_point)\n",
    "            return F.conv2d(x, weight, self.bias, self.stride,\n",
    "                        self.padding, self.dilation, self.groups)\n",
    "        else:\n",
    "            return F.conv2d(x, self.weight, self.bias, self.stride,\n",
    "                        self.padding, self.dilation, self.groups)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(ConvNet, self).__init__()\n",
    "\n",
    "        self.conv1 = QuantConv2d(1, 32, kernel_size=3, padding=1, stride=1)\n",
    "        self.relu1 = nn.ReLU(inplace=True)\n",
    "        self.maxpool1 = nn.MaxPool2d(2)\n",
    "\n",
    "        self.conv2 = QuantConv2d(32, 64, kernel_size=3, padding=1, stride=1)\n",
    "        self.relu2 = nn.ReLU(inplace=True)\n",
    "        self.maxpool2 = nn.MaxPool2d(2)\n",
    "\n",
    "        self.conv3 = QuantConv2d(64, 64, kernel_size=3, padding=1, stride=1)\n",
    "        self.relu3 = nn.ReLU(inplace=True)\n",
    "\n",
    "        self.linear1 = QuantLinear(7*7*64, 10)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        out = self.maxpool1(self.relu1(self.conv1(x)))\n",
    "        out = self.maxpool2(self.relu2(self.conv2(out)))\n",
    "        out = self.relu3(self.conv3(out))\n",
    "        out = out.view(out.size(0), -1)\n",
    "        out = self.linear1(out)\n",
    "        return out\n",
    "\n",
    "    def linear_quant(self, quantize_bit=8):\n",
    "        # Should be a less manual way to quantize\n",
    "        # Leave it for the future\n",
    "        self.conv1.linear_quant(quantize_bit)\n",
    "        self.conv2.linear_quant(quantize_bit)\n",
    "        self.conv3.linear_quant(quantize_bit)\n",
    "        self.linear1.linear_quant(quantize_bit)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, device, train_loader, optimizer, epoch):\n",
    "    model.train()\n",
    "    total = 0\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = F.cross_entropy(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        total += len(data)\n",
    "        progress = math.ceil(batch_idx / len(train_loader) * 50)\n",
    "        print(\"\\rTrain epoch %d: %d/%d, [%-51s] %d%%\" %\n",
    "              (epoch, total, len(train_loader.dataset),\n",
    "               '-' * progress + '>', progress * 2), end='')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(model, device, test_loader):\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss\n",
    "            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "\n",
    "    print('\\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(\n",
    "        test_loss, correct, len(test_loader.dataset),\n",
    "        100. * correct / len(test_loader.dataset)))\n",
    "    return test_loss, correct / len(test_loader.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main():\n",
    "    epochs = 2\n",
    "    batch_size = 64\n",
    "    torch.manual_seed(0)\n",
    "\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST('../data/MNIST', train=True, download=False,\n",
    "                       transform=transforms.Compose([\n",
    "                           transforms.ToTensor(),\n",
    "                           transforms.Normalize((0.1307,), (0.3081,))\n",
    "                       ])),\n",
    "        batch_size=batch_size, shuffle=True)\n",
    "    test_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST('../data/MNIST', train=False, download=False, transform=transforms.Compose([\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize((0.1307,), (0.3081,))\n",
    "        ])),\n",
    "        batch_size=1000, shuffle=True)\n",
    "\n",
    "    model = ConvNet().to(device)\n",
    "    optimizer = torch.optim.Adadelta(model.parameters())\n",
    "    \n",
    "    for epoch in range(1, epochs + 1):\n",
    "        train(model, device, train_loader, optimizer, epoch)\n",
    "        _, acc = test(model, device, test_loader)\n",
    "    \n",
    "    quant_model = deepcopy(model)\n",
    "    print('\\n')\n",
    "    print('=='*10)\n",
    "    print('4 linear bits quantization')\n",
    "    quant_model.linear_quant(quantize_bit=4)\n",
    "    _, acc = test(quant_model, device, test_loader)\n",
    "    \n",
    "    return model, quant_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train epoch 1: 60000/60000, [-------------------------------------------------->] 100%\n",
      "Test: average loss: 0.0471, accuracy: 9847/10000 (98%)\n",
      "Train epoch 2: 60000/60000, [-------------------------------------------------->] 100%\n",
      "Test: average loss: 0.0252, accuracy: 9917/10000 (99%)\n",
      "\n",
      "\n",
      "====================\n",
      "4 linear bits quantization\n",
      "\n",
      "Test: average loss: 0.0268, accuracy: 9916/10000 (99%)\n"
     ]
    }
   ],
   "source": [
    "model, quant_model = main()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 可视化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_weights(model):\n",
    "    modules = [module for module in model.modules()]\n",
    "    num_sub_plot = 0\n",
    "    for i, layer in enumerate(modules):\n",
    "        if hasattr(layer, 'weight'):\n",
    "            plt.subplot(221+num_sub_plot)\n",
    "            w = layer.weight.data\n",
    "            w_one_dim = w.cpu().numpy().flatten()\n",
    "            plt.hist(w_one_dim, bins=50)\n",
    "            num_sub_plot += 1\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD4CAYAAAAEhuazAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAXGElEQVR4nO3df6xcZZ3H8fdnW4F1XW1LW6xAuZCtWdlsLOYGyGpEBbFgYvkDFHZZi5I0UUyWsJt4XUxwUWPxj3UxumijXYq7S/nhEm6oypYCMUus0iqWXwstiPam3Ra2iBKyaPG7f5zn4vR25s7c+XFmznk+r2QyZ545M/M9M8/5zjPPeeY5igjMzCwvfzDsAMzMrHxO/mZmGXLyNzPLkJO/mVmGnPzNzDI0f9gBzGbx4sUxNjY27DCsxnbs2PFcRCwp+3Vdt22QOqnXI538x8bG2L59+7DDsBqT9PNhvK7rtg1SJ/Xa3T5mZhly8jczy5CTv5lZhka6z79uxiY2A/DMuvf37bn69XyDft6yX8OqxXWifG75m5llyMnfzEbK2MTmw34J2GA4+ZuZZcjJ38wsQ07+ZmYZcvI3M8uQk7+ZWYY8zn/Auh210M1/AlqNlfYYajObyS1/M7MMueVvZgPVyS/Sbp/Duufkb2al6TbhW/+528fMLEMDSf6SNkg6IOmRhrJFkrZI2pWuFw7itc069dGPfhTgrZ3UUxW+LGm3pJ2S3tbwmDVp/V2S1pS/JXmZnv7Bvwx6M6iW/43AqhllE8DWiFgBbE23zYbmsssuA9g1o7hVPT0PWJEua4EboPiyAK4BzgBOB65xw8aqYCDJPyK+DxycUbwa2JiWNwIXDOK1zTr1zne+E+DQjOJW9XQ1cFMUtgELJC0D3gdsiYiDEfE8sIUjGz5mI6fMA77HRcQ+gIjYJ2lps5UkraVoWbF8+fISw+tNP0YjlPkz1qMnWmpVT48H9jSsN5XKWpUfoap12+pp5A74RsT6iBiPiPElS2Y9+bxZmdSkLGYpP7LQddtGSJnJf3/6mUy6PlDia5t1qlU9nQJObFjvBGDvLOVmI63M5D8JTI+EWAPcWeJrm3WqVT2dBD6cRv2cCbyQuofuBs6VtDAd6D03lVkJPOqnewPp85d0M/AuYLGkKYrREOuAWyVdDvwCuGgQr23WqUsuuQTgTylGcrarp98Bzgd2Ay8BHwGIiIOSPgs8mNa7NiJmDnYwGzkDSf4RcUmLu84exOuZdePmm29m06ZNOyNifMZdR9TTiAjgimbPExEbgA0DCNFsYDy9g5kNhLtjRtvIjfYxM7PBc8t/AEahxeMJtMxsNm75m5llyMnfzCxD7vYxs8rzdCVz55a/mVmGnPzNzDLk5G9mliH3+Q/BXPonO1l3WEM13c9qVl1u+ZuZZcjJ38wsQ+72MbO+8b/Fq8MtfzOzDDn5m5llyN0+ZlYrHoXWGbf8zcwy5JZ/D+p+cKtZC6rf2zz9fG6hmZXLLX8zsww5+ZuZZcjJ38wsQ07+ZmYZcvI3M8uQR/uYWW15zH9rTv5m1pO6D3muKyf/Do1CC2KUd7J2/wlwq8tstLjP38wsQ07+ZmYZcvI3M8uQk7+ZWYZKP+Ar6Rng18ArwKGIGC87BrN2mtVTSYuAW4Ax4BnggxHxvCQB1wPnAy8Bl0XEj4cRd1lGefCBdWZYLf93R8RKJ34bcTPr6QSwNSJWAFvTbYDzgBXpsha4ofRIzebI3T5mnVsNbEzLG4ELGspvisI2YIGkZcMI0KxTwxjnH8B/Sgrg6xGxvvFOSWspWk8sX758COH9nn/aDt8Q/yvQrJ4eFxH7ACJin6Slad3jgT0Nj51KZfsan3CU6rbZMJL/2yNib9pxtkj674j4/vSdaSdbDzA+Ph5DiM8MmtTTWdZVk7Ij6q7r9nD5xEGHK73bJyL2pusDwB3A6WXHYNZOi3q6f7o7J10fSKtPASc2PPwEYG950ZrNXanJX9IfSfrj6WXgXOCRMmMwa2eWejoJrEmrrQHuTMuTwIdVOBN4Ybp7yGxUld3tcxxwRzEyjvnAv0fE90qOwaydpvVU0oPArZIuB34BXJTW/w7FMM/dFEM9P1J+yGZzU2ryj4ingbeW+Zpmc9WqnkbE/wJnNykP4IoSQhs6D4KoDw/1NDPLkJO/mVmGPJ9/MqxhYFX5Gd1rnO0e38l4fp8fwKx/3PI3M8uQW/5mlhX/giy45W9mliG3/M0sWzn/CnDL38wsQ07+ZmYZcvI3M8uQ+/xn6GQ8e1XG5jdTRuxVfn/McuGWv5lZhtzyN7NZ+ZdcPTn5m5mR37BPJ38zO4Jb+/XnPn8zsww5+ZuZZcjJ38wsQ1n0+ed2IKfq5tLf7M/WrDtZJH8za88HefPibh8zsxnGJjbX/svQLX+zzNU9yVlzbvmbmWXILX8zsxbqPKDALX8zswxVtuXf7Tdynb/J7ff8OZvNrrLJ38ysTHVrUDj5m2XII3x6U4cvAid/s0w44VsjJ38zsx5U9VdA6clf0irgemAe8I2IWFd2DGb9Nqr12q19a6XU5C9pHvBV4L3AFPCgpMmIeKzMOMz6aVTq9XSir1Lrs26afdmO6udRdsv/dGB3RDwNIGkTsBpw8rcqG6l67db+aGn1eQz7S6Hs5H88sKfh9hRwRuMKktYCa9PNFyU90e5JdR2LgefmGoyum+sjOtZVPANUu3iafXatPs82n/NJvcSRtK3X0F3dbmLUPssy1HKb29TLXre5bb0uO/mrSVkcdiNiPbB+Tk8qbY+I8V4C6yfHM7tRi6cP2tZr6K5uH/FC9Xvv2vI2D0bZ0ztMASc23D4B2FtyDGb95nptlVN28n8QWCHpZElHARcDkyXHYNZvrtdWOaV2+0TEIUmfAO6mGBK3ISIe7cNT9/RTegAcz+xGLZ6eDLBeN1Or965D3uYBUMQRXZNmZlZzntLZzCxDTv5mZhmqTPKXtEjSFkm70vXCFuu9IumhdJlsKD9Z0g/T429JB+YGGo+klZJ+IOlRSTslfajhvhsl/awh1pVdxrFK0hOSdkuaaHL/0Wl7d6ftH2u471Op/AlJ7+vm9buI5ypJj6X3Y6ukkxrua/rZ5azXelYlvdTlKuplX+mLiKjEBfgiMJGWJ4DrWqz3YovyW4GL0/LXgI8NOh7gzcCKtPwmYB+wIN2+EbiwxxjmAU8BpwBHAT8FTp2xzseBr6Xli4Fb0vKpaf2jgZPT88wrIZ53A69Nyx+bjme2zy7nS6/1rCqXXupyFS+97iv9uFSm5U/xd/mNaXkjcEGnD5Qk4D3A7d08vtt4IuLJiNiVlvcCB4AlPb5uo1enFYiI3wDT0wq0ivN24Oz0fqwGNkXEyxHxM2B3er6BxhMR90XES+nmNoox8dbaKNSzMvRSl6to6PtKlZL/cRGxDyBdL22x3jGStkvaJml6RzkW+GVEHEq3pyj+kl9GPABIOp3iG/6phuLPp590X5J0dBcxNJtWYOZ2vbpO2v4XKN6PTh47iHgaXQ58t+F2s88ud/2oZ1XQS12uol73lZ6N1Hz+ku4B3tjkrqvn8DTLI2KvpFOAeyU9DPyqyXptx7j2KR4kLQO+BayJiN+l4k8B/0Oxo64HPglcO5fnpbNpBVqt09GUBAOIp1hRuhQYB85qKD7is4uIqiWxORtwPauKXupyFfW6r/RspJJ/RJzT6j5J+yUti4h9qZIfaPEce9P105LuB04Dvg0skDQ/tRg6+vt9P+KR9HpgM/DpiNjW8Nz70uLLkv4F+Lt28TTRybQC0+tMSZoPvAE42OFjBxEPks6hSGxnRcTL0+UtPrvaJ/9B1rMK6aUuV1FP+0o/VKnbZxJYk5bXAHfOXEHSwunuE0mLgbcDj0VxxOQ+4MLZHj+AeI4C7gBuiojbZty3LF2Loh/3kS5i6GRagcY4LwTuTe/HJHBxGkFxMrAC+FEXMcwpHkmnAV8HPhARBxrKm352PcZTBz3VswrppS5XUdf7St8M+6h3pxeKvr2twK50vSiVj1OcOQngL4CHKY6cPwxc3vD4UyiS227gNuDoEuK5FPgt8FDDZWW6794U4yPAvwKv6zKO84EnKVrIV6eyaykqDMAxaXt3p+0/peGxV6fHPQGc16fPqV089wD7G96PyXafXc6XXutZlS691OUqXrrdV/p18fQOZmYZqlK3j5mZ9YmTv5lZhpz8zcwyNFJDPWdavHhxjI2NDTsMq7EdO3Y8FxGl/xvWddsGqZN6PdLJf2xsjO3btw87DKsxST8fxuu6btsgdVKv3e1jZpahtslf0jGSfiTpp2nK2H9I5SeryRTJs027qgFMIWxmZnPXScv/ZeA9EfFWYCWwStKZwHXAlyJiBfA8xcRDpOvnI+JPgC+l9ZB0KsW/2P4MWAX8s6R5/dwYMzPrTNs+/yj+BfZiuvmadAmKKZL/MpVvBD4D3EAxLelnUvntwFdmTiEM/EzS9BTCP+jHhuRgbGJz0/Jn1r2/5EjMqqfZ/pPzvtNRn7+keZIeophUagvF35FbTZHc0xTCktamaX23P/vss3PfogyNTWxu+cVgZtZMR8k/Il6JiJUUM8+dDryl2WrpuqcphCNifUSMR8T4kiVVOx+FmVk1zGm0T0T8ErgfOJM0RXK6q3E60lenKi1hCmEzM+tCJ6N9lkhakJb/EDgHeJzWUySXOYWwWc8k/UTSXWnZo9gsC520/JcB90naSTEH9ZaIuIvizFNXpQO3xwLfTOt/Ezg2lV9FcdJpIuJRipOoPwZ8D7giIl7p58aYdeE4isbMNI9isyx0MtpnJ8UZlWaWP02TE35HxP8BF7V4rs8Dn597mGb9NzU1BUW35DcoGjLCo9gsEyM9vYPNTeOIn5yHsHXqyiuvhOJY1PT5bo+lw1FskhpHsTWeNrHlibglrQXWAixfvrxv22HWDSf/EechnINx1113sXTpUoCXGopnG5HW0yg2KEayAesBxsfHfRYlGyonf8vSAw88wOTkJMCfA5uA1wP/RBrFllr/zUaxzTx5uEexWSV5YjfL0he+8IXpPv+HKQ7Y3hsRf4VHsVkm3PI3O9wngU2SPgf8hMNHsX0rHdA9SPGFQUQ8Kml6FNshPIrNKsLJ37IXEfdT/HnRo9gsG+72MTPLkJO/mVmGnPzNzDLkPn8zy1bOf4x0y9/MLENu+ddUzi0aM2vPyX9EeVoHMxskd/uYmWXILX8zqy3/gm7NyT8D7v83s5nc7WNmliEnfzOzDDn5m5llyMnfzCxDTv5mZhly8jczy1Db5C/pREn3SXpc0qOS/iaVL5K0RdKudL0wlUvSlyXtlrRT0tsanmtNWn+XpDWtXtPMzAark5b/IeBvI+ItwJnAFZJOBSaArRGxAtiabgOcR3Ee0xXAWuAGKL4sgGuAMyjOlHTN9BeGmZmVq23yj4h9EfHjtPxr4HHgeGA1sDGtthG4IC2vBm6KwjZggaRlwPuALRFxMCKeB7YAq/q6NdbW2MRm/+vRzObW5y9pDDgN+CFwXETsg+ILAliaVjse2NPwsKlU1qp85muslbRd0vZnn312LuGZmVmHOk7+kl4HfBu4MiJ+NduqTcpilvLDCyLWR8R4RIwvWbKk0/DMzGwOOprbR9JrKBL/v0XEf6Ti/ZKWRcS+1K1zIJVPASc2PPwEYG8qf9eM8vu7D71+3B1jZmXpZLSPgG8Cj0fEPzbcNQlMj9hZA9zZUP7hNOrnTOCF1C10N3CupIXpQO+5qczMzErWScv/7cBfAw9LeiiV/T2wDrhV0uXAL4CL0n3fAc4HdgMvAR8BiIiDkj4LPJjWuzYiDvZlK8zMbE7aJv+I+C+a99cDnN1k/QCuaPFcG4ANcwnQzMz6z/P5m1mt+NhZZzy9g5lZhpz8zcwy5ORvZpYhJ38zsww5+ZuZZcijfczMOHyU0DPr3j/ESMrhlr+ZWYbc8s9Ubq0cMzuck/+Q+Q8pw7Nnzx6AN0t6HPgdsD4irk8nHroFGAOeAT4YEc+nea6up5i+5CXgsulzXaQz0306PfXnImIjZiPM3T6Wrfnz5wNM+Sx1liMnf8vWsmXLoGjB+yx1lh0nfzN8ljrLj5O/Zc9nqbMc+YCv5U74LHWV54ETc+eWv2WrOPUEJ+Gz1FmG3PK3bD3wwAMAxwLv8VnqLDdO/patd7zjHQA7ImK8yd0+S53Vmrt9zMwy5ORvZpYhd/sMgUcmmNmwueVvZpahtslf0gZJByQ90lC2SNIWSbvS9cJULklflrRb0k5Jb2t4zJq0/q40CZaZmQ1JJ90+NwJfAW5qKJue+GqdpIl0+5McPvHVGRQTX53RMPHVOMU/H3dImkzzoNiQeXpnqyp3oXavbcs/Ir4PzByz7ImvzMwqrNs+/4FMfAWe/MrMrAz9PuDb08RX4MmvzGz4xiY2175Lqdvkvz915zCHia+alZuZ2RB0m/w98ZWZWYW1He0j6WaK6WoXS5qiGLXjia/MzCqsbfKPiEta3OWJr+ao7n2IZlYdnt7BDuMx/2Z5cPI3s0rxL+j+8Nw+ZmYZcvI3M8uQk7+ZWYac/M3MMuQDvgPmg1NmNoqc/K0lD/u0UTGsRlSd9wF3+5iZZcjJ38wsQ+72GYA69vNPb1Pdfvqa5cotfzOzDDn5m5llyN0+ZjaSRq37tG4jf5z8bU7qtgOY5crJv09GrZViZjYbJ38zGyluSJXDyd+65i4gs+py8jezoataa78ODR8n/x5UrcKamU1z8u+Ck/6R6tASsnLVZT+qat138re+81QQ1kpdEn4dlJ78Ja0CrgfmAd+IiHVlx9ANV9q5q2qLqBtVrddlyGnfqVLDp9TkL2ke8FXgvcAU8KCkyYh4rMw4OpVTpR20On8RVK1elyXn/acK9b3slv/pwO6IeBpA0iZgNTDQnSTnSjiK2n0eo7qzzGIo9XqYvE91rtV7Nex6XnbyPx7Y03B7CjijcQVJa4G16eaLkp7ow+suBp7rw/NUQeW3Vdd1vGo/tvWkHh8PHdRrGFjdHpbK1zOGvA1zqOezabUNbet12clfTcrisBsR64H1fX1RaXtEjPfzOUeVt3Uo2tZrGEzdHpYReu+7lvs2lD2l8xRwYsPtE4C9Jcdg1m+u11Y5ZSf/B4EVkk6WdBRwMTBZcgxm/eZ6bZVTardPRByS9AngboohcRsi4tESXroWP7U75G0t2RDr9TCNxHvfo6y3QRFHdE2amVnN+TSOZmYZcvI3M8tQLZO/pEWStkjala4XNllnpaQfSHpU0k5JHxpGrN2QtErSE5J2S5pocv/Rkm5J9/9Q0lj5UfZPB9t7laTH0ue4VVI/xu5bgyrvU3XYXwayD0RE7S7AF4GJtDwBXNdknTcDK9Lym4B9wIJhx97Bts0DngJOAY4CfgqcOmOdjwNfS8sXA7cMO+4Bb++7gdem5Y9VeXtH9VLVfaoO+8ug9oFatvwp/lq/MS1vBC6YuUJEPBkRu9LyXuAAsKS0CLv36lQCEfEbYHoqgUaN2387cLakZn9EqoK22xsR90XES+nmNopx9tZfVd2n6rC/DGQfqGvyPy4i9gGk66WzrSzpdIpv1KdKiK1XzaYSOL7VOhFxCHgBOLaU6Pqvk+1tdDnw3YFGlKeq7lN12F8Gsg9Udj5/SfcAb2xy19VzfJ5lwLeANRHxu37ENmCdTCXQ0XQDFdHxtki6FBgHzhpoRDVV032qDvvLQPaByib/iDin1X2S9ktaFhH7UkU80GK91wObgU9HxLYBhdpvnUwlML3OlKT5wBuAg+WE13cdTZ0g6RyKJHVWRLxcUmy1UtN9qg77y0D2gbp2+0wCa9LyGuDOmSukv+HfAdwUEbeVGFuvOplKoHH7LwTujXQkqILabq+k04CvAx+IiKZJyXpW1X2qDvvLYPaBYR/JHtDR8WOBrcCudL0olY9TnGUJ4FLgt8BDDZeVw469w+07H3iSoj/16lR2bfrgAY4BbgN2Az8CThl2zAPe3nuA/Q2f4+SwY67bpcr7VB32l0HsA57ewcwsQ3Xt9jEzs1k4+ZuZZcjJ38wsQ07+ZmYZcvI3M8uQk7+ZWYac/M3MMvT/H4NEhD5I+v0AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 4 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_weights(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAW+0lEQVR4nO3dbYxcV33H8e+vhqSU0sbGD3VthzXUBEIKJl0lkYIiII3jBAkHiSD7RbIhQUaVLYHEm6VvzEMjuRUQNRKNasoqTgU2LhDFIm7C4hIhEA9e0+DECZaXYOLFlu10Q6BEQB3+fXHPmtn1zHrWnntn7pzfR1rt3HOfzl0f/+fcc885VxGBmZnl4Y+6nQEzM6uOg76ZWUYc9M3MMuKgb2aWEQd9M7OMvKzbGZjNwoULY2BgoNvZsD62f//+5yJiUdXnddm2Ms1Wrns66A8MDDA2NtbtbFgfk/SzbpzXZdvKNFu5dvOOmVlGHPTNzDLioG9mlpFS2vQl/THwLeDidI4vR8QWSSuBncAC4IfAbRHxuzLyULWB4YfPfD6y9V0dO16rY3XyfJ3Ou1krLmvdV1ZN/7fAOyPiLcBqYK2ka4B/BO6JiFXA88BdJZ3fzGpqYPjhaV8O1lmlBP0o/G9afHn6CeCdwJdT+nbgljLOb2ZmzZXWpi9pnqTHgZPAKPAT4BcRcTptMgEsa7LfRkljksZOnTpVVvbMzLJUWtCPiJciYjWwHLgKeGOzzZrsty0iBiNicNGiysfMmJn1tdJ770TEL4DHgGuASyRNPTxeDhwr+/xmZvYHZfXeWQT8X0T8QtIrgL+leIj7TeC9FD14hoCHyji/uZeEmTVX1jQMS4HtkuZR3E3sioivSXoK2CnpH4D/Bj5f0vnNzKyJUoJ+RBwA3tok/RmK9n0zM+sCj8itkPsfm1m3OeibmWXEQd/MLCMO+mZmGXHQNzPLiIO+mVlGHPQtW7/5zW8A3ijpR5IOSvo4gKSVkr4v6bCkL0m6KKVfnJbH0/qBqWNJ+mhKPyTpxm5cj1k7HPQtWxdffDHAoTlMAX4X8HxE/BVwT9oOSZcD64E3AWuBf0kDE816joO+ZUsSwO/TYjtTgK9Ly6T116s4yDpgZ0T8NiJ+CozjQYilmRrv4jEv58dB37I3hynAlwFHAdL6F4BXN6Y32afxPJ423LrOQd+yN4cpwNViXav0mefxtOHWdQ76ZrQ9BfgEsAIgrf9zYLIxvck+Zj3FQd8q1yttsqmJZR5AwxTgT/OHKcBh+hTgu9Myaf1/RUSk9PWpd89KYBXwgyquwWyuyppa2aznHT9+HOAySQdobwrwzwP/Lmmcooa/HiAiDkraBTwFnAY2RcRL1V6NWXsc9C1bb37zmwGeiojBxvRWU4BHxG+AW5sdKyLuBu4uIZs9zy/sqRc375iZZcRB38wsIw76ZmYZcdA3M8uIg76ZWUYc9M3MMlJKl01JK4AHgL+gmNBqW0T8s6QFwJeAAeAI8L6IeL6MPFi13G3PrB7KqumfBj4SEW+kGNa+KU0/OwzsTVPW7k3LZmZWkVKCfkQcj4gfps+/ohjavozpU9M2TllrZmYVKL1NP71d6K3A94ElEXEcii8GYHGT7T39bA/qhblyzOzClRr0Jf0p8BXgwxHxy3b28fSzZmblKS3oS3o5RcD/QkR8NSWfkLQ0rV9K8eIKMzOrSClBP71C7vPA0xHxmYZVjVPTNk5Za2ZmFShrls1rgduAJ9Kr6AD+HtgK7JJ0F/AsLWYsNDOzcpQS9CPi2zR/hRzA9WWc08zMzs0jcs3MMuKgb9k6evQowOslPS3poKQPAUhaIGlU0uH0e35Kl6R7JY1LOiDpyqljSRpK2x+WNNT8jGbd56Bv2XrZy14GMDGHkeM3Ubz/dhWwEbgPii8JYAtwNcUbt7ZMfVGY9RoHfcvW0qVLAV6EtkeOrwMeiML3gEtS1+MbgdGImExzSY0Cayu7ELM5cNA3o+2R48uAow27TaS0Vukzz+HR5tZ1DvqWvTmMHG/WIy1mSZ+e4NHm1gMc9C13ov2R4xPAioZ9lwPHZkk36zkO+patiAB4De2PHN8N3J568VwDvJCafx4F1kianx7grklpZj2nrBG5Zj3vO9/5DsCrgXe2OXJ8D3AzME7xAPj9ABExKemTwL603SciYrKSizCbIwd9y9bb3vY2gP0RMdhk9Vkjx6O4NdjU7FgRMQKMdDSDZiVw0G+TXwdoZv3AQd/M+pYra2fzg1wzs4w46JuZZcRB38wsIw76ZmYZcdA3M8uIg76ZWUYc9M3MMuJ++tZz3LfarDwO+mY2K38J9xc375iZZaSUoC9pRNJJSU82pDV92bSZmVWnrJr+/Zz9jtBWL5s2M7OKlBL0I+JbwMz5xFu9bNrMzCpSZZt+q5dNT+OXR5uZlafnHuT65dFmZuWpMui3etm0WVfceeedAG9pp8NBei/uvZLGJR2QdGXDPkNp+8OShs4+k1nvqDLot3rZtFlX3HHHHQCHZyS36nBwE7Aq/WwE7oPiSwLYAlwNXAVscc8062VlddncAXwXuEzSRHrB9FbgBkmHgRvSslnXXHfddQCnZyS36nCwDnggCt8DLkl3rDcCoxExGRHPA6Oc3XPNrGeUMiI3Ija0WHXWy6bNesy0DgeSpjocLAOONmw3kdJapZ9F0kaKuwQuvfTSDmfbrD099yDXrEepSVrMkn52ojspWA9w0DebrlWHgwlgRcN2y4Fjs6Sb9SQHfbPpWnU42A3cnnrxXAO8kJqBHgXWSJqfHuCuSWlmPcmzbFq2NmzYAPAGih6ZExS9cLYCu1Lng2eBW9Pme4CbgXHgReD9ABExKemTwL603SciYuZodOtROc4g6qBv2dqxYwc7d+48EBGDM1ad1eEgIgLY1Ow4ETECjJSQRbOOc/OOmVlGsq7p53hrZ2Z5c03fzCwjDvpmZhlx0DfL3MDww9OaOq2/OeibmWXEQd/MLCMO+lZLbpIwOz8O+mZmGXHQNzPLSF8HfTcBmJlN19dB38zsQvVb5dFB38wsIw76ZmYZcdA3M8uIg75ZH+q3dmjrnMqnVpa0FvhnYB7wbxGx9XyO42mRbTZVl49OlWuzslVa05c0D/gscBNwObBB0uVV5sGs01yubUod7rCqrulfBYxHxDMAknYC64CnKs6HWSdVWq59l2sXQsWrPys6mfReYG1EfCAt3wZcHRGbG7bZCGxMi5cBh1ocbiHwXInZLUtd8w31zfts+X5NRCy6kIO3U65Tust2b+rHfLcs11XX9NUkbdq3TkRsA7ad80DSWJMXWve8uuYb6pv3CvJ9znINLtu9Krd8V917ZwJY0bC8HDhWcR7MOs3l2mqj6qC/D1glaaWki4D1wO6K82DWaS7XVhuVNu9ExGlJm4FHKbq2jUTEwfM83Dlvk3tUXfMN9c17qfnucLkG/52rllW+K32Qa2Zm3eURuWZmGXHQNzPLSC2DvqS1kg5JGpc03O38tEvSEUlPSHpc0li389OKpBFJJyU92ZC2QNKopMPp9/xu5rGZFvn+mKSfp7/545Ju7mYeZ1PXcg0u22XrZNmuXdDvgyHv74iI1T3eL/h+YO2MtGFgb0SsAvam5V5zP2fnG+Ce9DdfHRF7Ks5TW/qgXIPLdpnup0Nlu3ZBn4Yh7xHxO2BqyLt1SER8C5ickbwO2J4+bwduqTRTbWiR77pwua6Ay3Y9g/4y4GjD8kRKq4MAvi5pfxqSXydLIuI4QPq9uMv5mYvNkg6kW+Seu3VP6lyuwWW7W+ZctusY9Nsa8t6jro2IKylu4TdJuq7bGcrAfcDrgNXAceDT3c1OS3Uu1+Cy3Q3nVbbrGPRrO+Q9Io6l3yeBBylu6evihKSlAOn3yS7npy0RcSIiXoqI3wOfo3f/5rUt1+Cy3Q3nW7brGPRrOeRd0islvWrqM7AGeHL2vXrKbmAofR4CHupiXto29Z85eQ+9+zevZbkGl+1uOd+yXfmbsy5UCUPeq7IEeFASFH/3L0bEI93NUnOSdgBvBxZKmgC2AFuBXZLuAp4Fbu1eDptrke+3S1pN0VRyBPhg1zI4ixqXa3DZLl0ny7anYTAzy0gdm3fMzOw8OeibmWXEQd/MLCM9/SB34cKFMTAw0O1sWB/bv3//cxf6jtzz4bJtZZqtXPd00B8YGGBsrGfnbrI+IOln3Tivy7aVabZy7eYdM7OMOOibmWXEQd/MLCM93aZv7RkYfvjM5yNb39XFnJjVw9T/mRz/vzjo9zgHdDPrJDfvmJllxEHfzCwjDvpmZhlx0Dczy4iDvplZRhz0zcwy4qBvZpYRB30zs4w46JuZZcRB38wsIw76ZmYZcdA3M8uIg76ZWUYc9M3MMnLOoC/pMkmPN/z8UtKHJX1M0s8b0m9u2OejksYlHZJ0Y0P62pQ2Lmm4rIsyM7PmzjmffkQcAlYDSJoH/Bx4EHg/cE9EfKpxe0mXA+uBNwF/CXxD0uvT6s8CNwATwD5JuyPiqQ5di5mZncNcX6JyPfCTiPiZpFbbrAN2RsRvgZ9KGgeuSuvGI+IZAEk707YO+mZmFZlrm/56YEfD8mZJBySNSJqf0pYBRxu2mUhprdKnkbRR0piksVOnTs0xe2Znu/POO1m8eDFXXHHFmbTJyUluuOEGgCskjU6VXxXuTU2QByRdObWPpCFJh9PPUEP630h6Iu1zr2apEZl1W9tBX9JFwLuB/0hJ9wGvo2j6OQ58emrTJrvHLOnTEyK2RcRgRAwuWrSo3eyZtXTHHXfwyCOPTEvbunUr119/PcCTwF5g6hnTTcCq9LORopwjaQGwBbia4s51S0NF57607dR+a0u8HLMLMpea/k3ADyPiBEBEnIiIlyLi98Dn+EMTzgSwomG/5cCxWdLNSnXdddexYMGCaWkPPfQQQ0NnKuvbgVvS53XAA1H4HnCJpKXAjcBoRExGxPPAKLA2rfuziPhuRATwQMOxzHrOXIL+BhqadlJhn/IeihoTwG5gvaSLJa2kqPn8ANgHrJK0Mt01rE/bmlXuxIkTLF1aFOGIOA4sTqvm2jy5LH2emX4WN11aL2jrQa6kP6HodfPBhuR/krSaoonmyNS6iDgoaRfFA9rTwKaIeCkdZzPwKDAPGImIgx26DrNOmWvzZFvNllA0XQLbAAYHB5tuY1a2toJ+RLwIvHpG2m2zbH83cHeT9D3Anjnm0azjlixZwvHjx4Ezd60n06rZmiffPiP9sZS+vMn2Zj3JI3ItS+9+97vZvn371OIQ8FD6vBu4PfXiuQZ4ITX/PAqskTQ/PcBdAzya1v1K0jWp187tDccy6zlz7advVjsbNmzgscce47nnnmP58uV8/OMfZ3h4mPe9730AVwAvALemzfcANwPjwIsUgxCJiElJn6R4NgXwiYiYTJ//DrgfeAXwn+nHrCc56Fvf27FjR9P0vXv3IunJiLh+Ki31wNnUbPuIGAFGmqSPUXx5mPU8N++YmWXENf3MDAw/fObzka3v6mJOzKwbXNM3M8uIg76ZWUYc9M3MMuKgb2aWET/INbPacweF9rmmb2aWEQd9M7OMOOibmWXEQd/MLCMO+mZmGXHQNzPLSFtBX9IRSU9IelzSWEpbIGlU0uH0e35Kl6R7JY1LOiDpyobjDKXtD0saanU+MzMrx1xq+u+IiNURMZiWh4G9EbEK2JuWoXiB+qr0sxG4D4ovCWALcDXFS9S3TH1RmJlZNS6keWcdMPXqoe3ALQ3pD0The8Al6XV0NwKjETEZEc8Do8DaCzi/mZnNUbsjcgP4uqQA/jW94HlJelUcEXFc0uK07TLgaMO+EymtVfo0kjZS3CFw6aWXzuFS6sOjB82sW9oN+tdGxLEU2Ecl/XiWbdUkLWZJn55QfKFsAxgcHDxrvZmZnb+2mnci4lj6fRJ4kKJN/kRqtiH9Ppk2nwBWNOy+HDg2S7qZmVXknEFf0islvWrqM7AGeBLYDUz1wBkCHkqfdwO3p1481wAvpGagR4E1kuanB7hrUpqZmVWkneadJcCDkqa2/2JEPCJpH7BL0l3As8Ctafs9wM3AOPAi8H6AiJiU9ElgX9ruExEx2bErMTOzczpn0I+IZ4C3NEn/H+D6JukBbGpxrBFgZO7ZNDOzTvCIXDOzjDjom5llxEHfzCwjDvpmZhlx0Dczy4iDvuXurz2DrOXEQd/MM8haRhz0zc7mGWStbznomxUzyO5PM7zCjBlkgY7NICtpTNLYqVOnOn0NZm1pd5ZNs37144i40jPIWi5c07fc/R94BlnLh4O+ZevXv/41pP8DnkHWcuHmHcvWiRMnAN4g6Ud4BlnLhIO+Zeu1r30twFMNXTUBzyBr/c3NO2ZmGXFN38xshoHhh898PrL1XV3MSee187rEFZK+KelpSQclfSilf0zSz9Pw9ccl3dywz0fTUPVDkm5sSF+b0sYlDTc7n5mZlaedmv5p4CMR8cP0rtz9kkbTunsi4lONG0u6HFgPvAn4S+Abkl6fVn8WuIGii9s+Sbsj4qlOXIiZmZ1bO69LPA5MjU78laSnaTLasME6YGdE/Bb4qaRxir7PAOPp9YtI2pm2ddDvUf18i2uWqzk9yJU0ALwV+H5K2pxmGxxpmGDKQ9XNzHpU20Ff0p8CXwE+HBG/pJhh8HXAaoo7gU9Pbdpk9zkNVY+IwYgYXLRoUbvZM7M+MzD88Jkf65y2eu9IejlFwP9CRHwVICJONKz/HPC1tDjbkHQPVTcz66J2eu8I+DzwdER8piF9acNm76EYvg7FUPX1ki6WtJJi7vEfUIxWXCVppaSLKB727u7MZZiZWTvaqelfC9wGPCHp8ZT298AGSaspmmiOAB8EiIiDknZRPKA9DWyKiJcAJG2mmJNkHjASEQc7eC1mZnYO7fTe+TbN2+P3zLLP3cDdTdL3zLafmZmVyyNyS+CujmbWqzz3jplZRhz0zcwy4qBvZpYRB30zs4w46JuZZcRB38wsIw76ZmYZcdA3M8uIB2fZefMgNLP6cU3fzCwjrumbWVf4TrE7XNM3M8uIa/ptcq3EzJqpW2xwTd/MLCOu6Vsl6lYbMutXldf0Ja2VdEjSuKThqs9vVgaX64JfZt77Kq3pS5oHfBa4geIF6vsk7Y6Ip6rMRzOuidr56uVybTZT1c07VwHjEfEMgKSdwDqK9+l23FQgdxCvhxr/e1Varrupxv9GXdNrFUpFRHUnk94LrI2ID6Tl24CrI2JzwzYbgY1p8TLgUIvDLQSeKzG7VemH66jzNbwmIhZdyAHaKdcpPaey3Q/XAPW9jpbluuqafrMXrE/71omIbcC2cx5IGouIwU5lrFv64Tr64Rou0DnLNeRVtvvhGqB/rqNR1Q9yJ4AVDcvLgWMV58Gs01yurTaqDvr7gFWSVkq6CFgP7K44D2ad5nJttVFp805EnJa0GXgUmAeMRMTB8zzcOW+Ta6IfrqMfruG8dbhcQ3/8PfvhGqB/ruOMSh/kmplZd3kaBjOzjDjom5llpJZBvx+GvEs6IukJSY9LGut2ftolaUTSSUlPNqQtkDQq6XD6Pb+beayrfijX4LLd62oX9BuGvN8EXA5skHR5d3N13t4REatr1g/4fmDtjLRhYG9ErAL2pmWbgz4r1+Cy3bNqF/RpGPIeEb8Dpoa8WwUi4lvA5IzkdcD29Hk7cEulmeoPLtddlkvZrmPQXwYcbVieSGl1E8DXJe1Pw/PrbElEHAdIvxd3OT911C/lGly2e1od59Nva8h7DVwbEcckLQZGJf041TQsT/1SrsFlu6fVsabfF0PeI+JY+n0SeJDi9r6uTkhaCpB+n+xyfuqoL8o1uGz3ujoG/doPeZf0SkmvmvoMrAGenH2vnrYbGEqfh4CHupiXuqp9uQaX7TqoXfNOCUPeu2EJ8KAkKP4NvhgRj3Q3S+2RtAN4O7BQ0gSwBdgK7JJ0F/AscGv3clhPfVKuwWW753kaBjOzjNSxecfMzM6Tg76ZWUYc9M3MMuKgb2aWEQd9M7OMOOibmWXEQd/MLCP/D4ONSYGv8NtjAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 4 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_weights(quant_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
